{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3ecfe074-b6cb-4389-97d8-29e1b41f62e7",
   "metadata": {},
   "source": [
    "# How to train a new model on a new dataset from scratch in CoreNet\n",
    "\n",
    "The purpose of this tutorial is to familiarize you with experimenting with new models and datasets from scratch in CoreNet. We implement a simple classification model on CIFAR10 dataset, then demonstrate how to launch the training and evaluation.\n",
    "\n",
    "Let's first make sure our current working directory is the root folder of the repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2c79508f-0138-417d-804c-7b0006439b25",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "if os.getcwd().endswith(\"tutorials\"):\n",
    "    os.chdir(\"..\")\n",
    "\n",
    "assert os.path.exists(\n",
    "    \"corenet\"\n",
    "), f\"We should be in the root repository folder, but we are in {os.getcwd()}\"\n",
    "\n",
    "! mkdir -p projects/playground_cifar10/classification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24e3ce2d-2916-4319-8c0c-73e808f0cddd",
   "metadata": {},
   "source": [
    "## Create new training configuration\n",
    "\n",
    "Let's start by creating a YAML training configuration in the `projects/playground_cifar10` folder. You can find more training recipes in `projects/` folder. \n",
    "\n",
    "Note: The following jupyter notebook cells in this tutorial start with `%%file <path>`. This header instructs jupyter to write the content of the cell to the specified `<path>`, when executed. We leverage this feature to generate the YAML and Python files."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "68ef9fe9-d6a7-4c15-a4d5-1f1b29ed9292",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting projects/playground_cifar10/classification/cifar10.yaml\n"
     ]
    }
   ],
   "source": [
    "%%file projects/playground_cifar10/classification/cifar10.yaml\n",
    "\n",
    "common:\n",
    "    log_freq: 2000                 # Log the training metrics every 2000 iterations.\n",
    "\n",
    "dataset:\n",
    "    category: classification\n",
    "    name: \"cifar10\"                # We'll register the \"cifar10\" name at DATASET_REGISTRY later in this tutorial.\n",
    "\n",
    "    # The `corenet-train` entrypoint uses train_batch_size0 and val_batch_size0 values to construct \n",
    "    # training/validation batches during training. The `corenet-eval` entrypoint uses eval_batch_size0 to \n",
    "    # construct batches during evaluation (ie test).\n",
    "    #\n",
    "    # The effective batch size is: num_nodes x num_gpus x train_batch_size0\n",
    "    train_batch_size0: 4\n",
    "    val_batch_size0: 4\n",
    "    eval_batch_size0: 1\n",
    "\n",
    "    workers: 2\n",
    "    persistent_workers: true\n",
    "    pin_memory: true\n",
    "\n",
    "model:\n",
    "    classification:\n",
    "        name: \"two_layer\"          # We'll register the \"two_layer\" name at MODEL_REGISTRY later in this tutorial.\n",
    "        n_classes: 10\n",
    "\n",
    "    layer:\n",
    "        # Weight initialization parameters:\n",
    "        conv_init: \"kaiming_normal\"\n",
    "        linear_init: \"trunc_normal\"\n",
    "        linear_init_std_dev: 0.02\n",
    "\n",
    "\n",
    "sampler:\n",
    "    name: batch_sampler\n",
    "\n",
    "    # The following dimensions will be passed to the dataset.__get__ method, and the dataset produces samples \n",
    "    # cropped and resized to the requested dimensions. \n",
    "    bs:\n",
    "        crop_size_width: 32\n",
    "        crop_size_height: 32\n",
    "\n",
    "loss:\n",
    "    category: classification\n",
    "    classification:\n",
    "        name: cross_entropy       # The implemention is available in \"corenet/loss_fn/\" folder.\n",
    "\n",
    "optim:\n",
    "    name: sgd\n",
    "    sgd:\n",
    "        momentum: 0.9\n",
    "\n",
    "scheduler:\n",
    "    name: fixed                    # The implementation is available in \"corenet/optims/scheduler/\" folder.\n",
    "    max_epochs: 2\n",
    "    fixed:\n",
    "        lr: 0.001                  # Fixed Learning Rate\n",
    "\n",
    "stats:\n",
    "  val: [\"loss\", \"top1\"]            # Metrics to log\n",
    "  train: [\"loss\", \"top1\"]\n",
    "  checkpoint_metric: top1          # Assigns a checkpoint to results/checkpoint_best.pt\n",
    "  checkpoint_metric_max: true\n"
   ]
  },
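  {
   "cell_type": "markdown",
   "id": "sketch-batch-size-md",
   "metadata": {},
   "source": [
    "The batch-size comment in the config above can be made concrete with a little arithmetic. The next cell is a minimal sketch and not part of CoreNet: `effective_batch_size` and `iterations_per_epoch` are hypothetical helper names, the single-node, single-GPU setup is an assumption, and we assume the final, possibly smaller, batch is kept rather than dropped."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-batch-size-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def effective_batch_size(num_nodes: int, num_gpus: int, batch_size0: int) -> int:\n",
    "    # Mirrors the formula in the config comment: num_nodes x num_gpus x train_batch_size0.\n",
    "    return num_nodes * num_gpus * batch_size0\n",
    "\n",
    "\n",
    "def iterations_per_epoch(num_samples: int, batch_size: int) -> int:\n",
    "    # Assumption: the final, possibly smaller, batch is kept rather than dropped.\n",
    "    return math.ceil(num_samples / batch_size)\n",
    "\n",
    "\n",
    "# Assumed setup: 1 node, 1 GPU, train_batch_size0 = 4, and the 50,000 CIFAR10 training images.\n",
    "batch_size = effective_batch_size(num_nodes=1, num_gpus=1, batch_size0=4)\n",
    "print(batch_size, iterations_per_epoch(50_000, batch_size))"
   ]
  },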
  {
   "cell_type": "markdown",
   "id": "850c7db0-8231-4796-a715-0bf29a1528a4",
   "metadata": {},
   "source": [
    "## Register model and dataset classes\n",
    "\n",
    "Now, let's define the \"cifar10\" dataset and \"two_layer\" model that we have used in the \n",
    "above config. You can find more datasets in `corenet/data/datasets` and more models in\n",
    "`corenet/modeling/models`, where directories represent tasks (e.g. \"classification\")."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "bcda60b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting corenet/data/datasets/classification/playground_dataset.py\n"
     ]
    }
   ],
   "source": [
    "%%file corenet/data/datasets/classification/playground_dataset.py\n",
    "\n",
    "from argparse import Namespace\n",
    "from typing import Any, Dict, Tuple\n",
    "\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from corenet.data.datasets import DATASET_REGISTRY\n",
    "from corenet.data.datasets.dataset_base import BaseDataset\n",
    "\n",
    "\n",
    "@DATASET_REGISTRY.register(name=\"cifar10\", type=\"classification\")\n",
    "class Cifar10(BaseDataset):\n",
    "    CLASS_NAMES = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "    \n",
    "    def __init__(self, opts: Namespace, **kwargs) -> None:\n",
    "        super().__init__(opts, **kwargs)\n",
    "        self._torchvision_dataset = torchvision.datasets.CIFAR10(\n",
    "            \"/tmp/cifar10_cache\",\n",
    "            train=self.is_training,\n",
    "            download=True,\n",
    "        )\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return len(self._torchvision_dataset)\n",
    "\n",
    "    def __getitem__(self, sample_size_and_index: Tuple[int]) -> Dict[str, Any]:\n",
    "        # In CoreNet, not only does the sampler determine the index of the samples, but\n",
    "        # also the sampler determines the crop size dynamically for each batch. This\n",
    "        # allows samplers to train multi-scale models more efficiently.\n",
    "        # See: corenet/data/sampler/variable_batch_sampler.py\n",
    "        (crop_size_h, crop_size_w, sample_index) = sample_size_and_index\n",
    "\n",
    "        img, target = self._torchvision_dataset[sample_index]\n",
    "\n",
    "        transform = transforms.Compose(\n",
    "            [\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n",
    "                transforms.Resize(size=(crop_size_h, crop_size_w)),\n",
    "            ]\n",
    "        )\n",
    "        img = transform(img)\n",
    "        return {\n",
    "            \"samples\": img,\n",
    "            \"targets\": target,\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3b53c116",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting corenet/modeling/models/classification/playground_model.py\n"
     ]
    }
   ],
   "source": [
    "%%file corenet/modeling/models/classification/playground_model.py\n",
    "\n",
    "import argparse\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch import nn\n",
    "\n",
    "from corenet.modeling.models import MODEL_REGISTRY\n",
    "from corenet.modeling.models.base_model import BaseAnyNNModel\n",
    "\n",
    "\n",
    "@MODEL_REGISTRY.register(\"two_layer\", type=\"classification\")\n",
    "class Net(BaseAnyNNModel):\n",
    "    \"\"\"A simple 2-layer CNN, inspired by https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html\"\"\"\n",
    "\n",
    "    def __init__(self, opts: argparse.Namespace) -> None:\n",
    "        super().__init__(opts)\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "        self.reset_parameters(opts)  # Initialize the weights\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = torch.flatten(x, 1)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
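  {
   "cell_type": "markdown",
   "id": "sketch-feature-size-md",
   "metadata": {},
   "source": [
    "To see where the `16 * 5 * 5` input size of `fc1` comes from, it helps to trace the spatial size of a 32x32 input through the layers. The next cell is a minimal sketch that applies the standard output-size formula for unpadded convolution and pooling layers; `out_size` is a hypothetical helper, not a CoreNet or PyTorch function. Because `fc1` is sized for 32x32 inputs, changing `crop_size_width` or `crop_size_height` in the config would also require resizing `fc1`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-feature-size-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:\n",
    "    # Standard output-size formula for convolution/pooling (dilation 1, floor rounding).\n",
    "    return (size + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "\n",
    "size = 32                                  # CIFAR10 images are 32x32\n",
    "size = out_size(size, kernel=5)            # conv1: 32 -> 28\n",
    "size = out_size(size, kernel=2, stride=2)  # pool:  28 -> 14\n",
    "size = out_size(size, kernel=5)            # conv2: 14 -> 10\n",
    "size = out_size(size, kernel=2, stride=2)  # pool:  10 -> 5\n",
    "flattened = 16 * size * size               # conv2 produces 16 channels\n",
    "print(size, flattened)"
   ]
  },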
  {
   "cell_type": "markdown",
   "id": "c51c2bda",
   "metadata": {},
   "source": [
    "### Launching the training\n",
    "\n",
    "You can train the model by specifying the yaml config file that we created earlier in \n",
    "this tutorial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3c63e142",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/m_sekhavat/miniconda3/envs/corenet/lib/python3.10/site-packages/turicreate/_deps/__init__.py:9: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives\n",
      "  from distutils.version import StrictVersion as _StrictVersion\n",
      "2024-04-18 00:19:55 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Random seeds are set to 0\n",
      "2024-04-18 00:19:55 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Using PyTorch version 2.2.1+cu121\n",
      "2024-04-18 00:19:55 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Available GPUs: 1\n",
      "2024-04-18 00:19:55 - \u001b[34m\u001b[1mLOGS   \u001b[0m - CUDNN is enabled\n",
      "2024-04-18 00:19:56 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Setting --ddp.world-size the same as the number of available gpus.\n",
      "2024-04-18 00:19:56 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Directory exists at: results/run_1\n",
      "2024-04-18 00:19:59 - \u001b[32m\u001b[1mINFO   \u001b[0m - distributed init (rank 0): tcp://m-sekhavat-dev2:30786\n",
      "Files already downloaded and verified\n",
      "2024-04-18 00:20:01 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Training dataset details are given below\n",
      "Cifar10(\n",
      "\troot= \n",
      "\tis_training=True \n",
      "\tnum_samples=50000\n",
      ")\n",
      "Files already downloaded and verified\n",
      "2024-04-18 00:20:01 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Validation dataset details are given below\n",
      "Cifar10(\n",
      "\troot= \n",
      "\tis_training=False \n",
      "\tnum_samples=10000\n",
      ")\n",
      "2024-04-18 00:20:01 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Training sampler details: BatchSamplerDDP(\n",
      "\t num_repeat=1\n",
      "\t trunc_rep_aug=False\n",
      "\t sharding=False\n",
      "\t disable_shuffle_sharding=False\n",
      "\tbase_im_size=(h=32, w=32)\n",
      "\tbase_batch_size=4\n",
      ")\n",
      "2024-04-18 00:20:01 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Validation sampler details: BatchSamplerDDP(\n",
      "\t num_repeat=1\n",
      "\t trunc_rep_aug=False\n",
      "\t sharding=False\n",
      "\t disable_shuffle_sharding=False\n",
      "\tbase_im_size=(h=32, w=32)\n",
      "\tbase_batch_size=4\n",
      ")\n",
      "2024-04-18 00:20:01 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Number of data workers: 2\n",
      "2024-04-18 00:20:01 - \u001b[32m\u001b[1mINFO   \u001b[0m - Trainable parameters: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias']\n",
      "2024-04-18 00:20:01 - \u001b[34m\u001b[1mLOGS   \u001b[0m - \u001b[36mModel\u001b[0m\n",
      "Net(\n",
      "  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
      "  (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
      "  (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
      ")\n",
      "\u001b[31m=================================================================\u001b[0m\n",
      "                                Net Summary\n",
      "\u001b[31m=================================================================\u001b[0m\n",
      "Total parameters     =    0.062 M\n",
      "Total trainable parameters =    0.062 M\n",
      "\n",
      "2024-04-18 00:20:01 - \u001b[33m\u001b[1mWARNING\u001b[0m - Profiling not available, dummy_input_and_label not implemented for this model.\n",
      "2024-04-18 00:20:01 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Using DistributedDataParallel.\n",
      "2024-04-18 00:20:02 - \u001b[34m\u001b[1mLOGS   \u001b[0m - \u001b[36mLoss function\u001b[0m\n",
      "CrossEntropy(\n",
      "\t ignore_idx=-1\n",
      "\t class_weighting=False\n",
      "\t label_smoothing=0.0\n",
      ")\n",
      "2024-04-18 00:20:02 - \u001b[34m\u001b[1mLOGS   \u001b[0m - \u001b[36mOptimizer\u001b[0m\n",
      "SGDOptimizer (\n",
      "\t dampening: [0]\n",
      "\t differentiable: [False]\n",
      "\t foreach: [None]\n",
      "\t lr: [0.1]\n",
      "\t maximize: [False]\n",
      "\t momentum: [0.9]\n",
      "\t nesterov: [False]\n",
      "\t weight_decay: [4e-05]\n",
      ")\n",
      "2024-04-18 00:20:02 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Max. epochs for training: 2\n",
      "2024-04-18 00:20:02 - \u001b[34m\u001b[1mLOGS   \u001b[0m - \u001b[36mLearning rate scheduler\u001b[0m\n",
      "FixedLRScheduler(\n",
      "\tlr=0.001\n",
      " )\n",
      "2024-04-18 00:20:02 - \u001b[32m\u001b[1mINFO   \u001b[0m - Configuration file is stored here: \u001b[36mresults/run_1/config.yaml\u001b[0m\n",
      "\u001b[31m===========================================================================\u001b[0m\n",
      "2024-04-18 00:20:04 - \u001b[32m\u001b[1mINFO   \u001b[0m - Training epoch 0\n",
      "2024-04-18 00:20:13 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [       1/10000000], loss: 2.353, top1: 0.0, LR: [0.001], Avg. batch load time: 8.939, Elapsed time:  9.60\n",
      "2024-04-18 00:20:23 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [    2001/10000000], loss: 2.1819, top1: 16.954, LR: [0.001], Avg. batch load time: 0.005, Elapsed time: 18.86\n",
      "2024-04-18 00:20:32 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [    4001/10000000], loss: 2.0067, top1: 24.2439, LR: [0.001], Avg. batch load time: 0.003, Elapsed time: 27.71\n",
      "2024-04-18 00:20:40 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [    6001/10000000], loss: 1.8889, top1: 29.3868, LR: [0.001], Avg. batch load time: 0.002, Elapsed time: 36.57\n",
      "2024-04-18 00:20:50 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [    8001/10000000], loss: 1.8088, top1: 32.6272, LR: [0.001], Avg. batch load time: 0.002, Elapsed time: 45.68\n",
      "2024-04-18 00:20:58 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [   10001/10000000], loss: 1.7547, top1: 34.9515, LR: [0.001], Avg. batch load time: 0.001, Elapsed time: 54.54\n",
      "2024-04-18 00:21:07 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [   12001/10000000], loss: 1.7098, top1: 36.7949, LR: [0.001], Avg. batch load time: 0.001, Elapsed time: 63.47\n",
      "2024-04-18 00:21:10 - \u001b[34m\u001b[1mLOGS   \u001b[0m - *** Training summary for epoch 0\n",
      "\t loss=1.6997 || top1=37.214\n",
      "2024-04-18 00:21:19 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [       4/   10000], loss: 0.9705, top1: 50.0, LR: [0.001], Avg. batch load time: 0.000, Elapsed time:  7.04\n",
      "2024-04-18 00:21:24 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [    8004/   10000], loss: 1.4266, top1: 48.2634, LR: [0.001], Avg. batch load time: 0.000, Elapsed time: 12.46\n",
      "2024-04-18 00:21:26 - \u001b[34m\u001b[1mLOGS   \u001b[0m - *** Validation summary for epoch 0\n",
      "\t loss=1.4324 || top1=48.25\n",
      "2024-04-18 00:21:26 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Best checkpoint with score 48.25 saved at results/run_1/checkpoint_best.pt\n",
      "2024-04-18 00:21:26 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Deleting checkpoint: results/run_1/checkpoint_score_1.2353.pt\n",
      "2024-04-18 00:21:26 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Averaging checkpoints: ['checkpoint_score_1.4228.pt', 'checkpoint_score_48.2500.pt', 'checkpoint_score_49.2900.pt', 'checkpoint_score_49.8700.pt', 'checkpoint_score_54.7400.pt']\n",
      "2024-04-18 00:21:26 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Averaged checkpoint saved at: results/run_1/checkpoint_avg.pt\n",
      "2024-04-18 00:21:26 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Last training checkpoint is saved at: results/run_1/training_checkpoint_last.pt\n",
      "2024-04-18 00:21:26 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Last checkpoint's model state is saved at: results/run_1/checkpoint_last.pt\n",
      "\u001b[31m===========================================================================\u001b[0m\n",
      "2024-04-18 00:21:28 - \u001b[32m\u001b[1mINFO   \u001b[0m - Training epoch 1\n",
      "2024-04-18 00:21:28 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   1 [   12501/10000000], loss: 1.2167, top1: 50.0, LR: [0.001], Avg. batch load time: 0.042, Elapsed time:  0.05\n",
      "2024-04-18 00:21:37 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   1 [   14501/10000000], loss: 1.424, top1: 49.1004, LR: [0.001], Avg. batch load time: 0.000, Elapsed time:  8.88\n",
      "2024-04-18 00:21:46 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   1 [   16501/10000000], loss: 1.4023, top1: 49.7438, LR: [0.001], Avg. batch load time: 0.000, Elapsed time: 17.92\n",
      "2024-04-18 00:21:55 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   1 [   18501/10000000], loss: 1.3817, top1: 50.4916, LR: [0.001], Avg. batch load time: 0.000, Elapsed time: 26.66\n",
      "2024-04-18 00:22:03 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   1 [   20501/10000000], loss: 1.3659, top1: 51.1499, LR: [0.001], Avg. batch load time: 0.000, Elapsed time: 35.22\n",
      "2024-04-18 00:22:12 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   1 [   22501/10000000], loss: 1.3521, top1: 51.6673, LR: [0.001], Avg. batch load time: 0.000, Elapsed time: 43.97\n",
      "2024-04-18 00:22:21 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   1 [   24501/10000000], loss: 1.3378, top1: 52.1477, LR: [0.001], Avg. batch load time: 0.000, Elapsed time: 52.74\n",
      "2024-04-18 00:22:23 - \u001b[34m\u001b[1mLOGS   \u001b[0m - *** Training summary for epoch 1\n",
      "\t loss=1.3375 || top1=52.196\n",
      "2024-04-18 00:22:25 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   1 [       4/   10000], loss: 1.1664, top1: 25.0, LR: [0.001], Avg. batch load time: 0.000, Elapsed time:  0.01\n",
      "2024-04-18 00:22:30 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   1 [    8004/   10000], loss: 1.2198, top1: 56.6217, LR: [0.001], Avg. batch load time: 0.000, Elapsed time:  5.44\n",
      "2024-04-18 00:22:32 - \u001b[34m\u001b[1mLOGS   \u001b[0m - *** Validation summary for epoch 1\n",
      "\t loss=1.2256 || top1=56.54\n",
      "2024-04-18 00:22:32 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Best checkpoint with score 56.54 saved at results/run_1/checkpoint_best.pt\n",
      "2024-04-18 00:22:32 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Deleting checkpoint: results/run_1/checkpoint_score_1.4228.pt\n",
      "2024-04-18 00:22:32 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Averaging checkpoints: ['checkpoint_score_48.2500.pt', 'checkpoint_score_49.2900.pt', 'checkpoint_score_49.8700.pt', 'checkpoint_score_54.7400.pt', 'checkpoint_score_56.5400.pt']\n",
      "2024-04-18 00:22:32 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Averaged checkpoint saved at: results/run_1/checkpoint_avg.pt\n",
      "2024-04-18 00:22:32 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Last training checkpoint is saved at: results/run_1/training_checkpoint_last.pt\n",
      "2024-04-18 00:22:32 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Last checkpoint's model state is saved at: results/run_1/checkpoint_last.pt\n",
      "2024-04-18 00:22:32 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Training took 00:02:30.57\n"
     ]
    }
   ],
   "source": [
    "! corenet-train --common.config-file projects/playground_cifar10/classification/cifar10.yaml"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee65265e-98ec-43c5-8dd7-1e07501b6522",
   "metadata": {},
   "source": [
    "By running the above command, you should observe that the model's validation accuracy has increased from ~10% to 54.74%.\n",
    "\n",
    "We can find the saved checkpoints in the results folder:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b1b7511a-adc1-4d89-80e3-f1ea71a94c14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[01;34mresults/\u001b[00m\n",
      "├── \u001b[01;34mclassification_results\u001b[00m\n",
      "│   └── \u001b[01;34mrun_1\u001b[00m\n",
      "├── logs.txt\n",
      "└── \u001b[01;34mrun_1\u001b[00m\n",
      "    ├── checkpoint_avg.pt\n",
      "    ├── checkpoint_best.pt\n",
      "    ├── checkpoint_last.pt\n",
      "    ├── checkpoint_score_48.2500.pt\n",
      "    ├── checkpoint_score_49.2900.pt\n",
      "    ├── checkpoint_score_49.8700.pt\n",
      "    ├── checkpoint_score_54.7400.pt\n",
      "    ├── checkpoint_score_56.5400.pt\n",
      "    ├── config.yaml\n",
      "    └── training_checkpoint_last.pt\n",
      "\n",
      "3 directories, 11 files\n"
     ]
    }
   ],
   "source": [
    "! tree results/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39e768f8-ad90-409c-8eaf-f3caf02a06ba",
   "metadata": {},
   "source": [
    "### Launching the evaluation\n",
    "CoreNet follows the standard practice of splitting datasets into `train`, `val`, and `test` splits. The `corenet-train` entrypoint that was used in the previous section consumes `train` and `val` splits for training the model and finding the best checkpoint. In this section, we will use `corenet-eval` entrypoint, that evaluates a model checkpoint on the `test` split of a dataset.\n",
    "\n",
    "In the below command, we set `CUDA_VISIBLE_DEVICES=0` environment variable, as a good reproducibility practice for evaluation to use only 1 gpu for inference. Otherwise, CoreNet uses all available GPUs by default."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "72b847fb-0b2c-4d6a-a045-6cdb52ced476",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/m_sekhavat/miniconda3/envs/corenet/lib/python3.10/site-packages/turicreate/_deps/__init__.py:9: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives\n",
      "  from distutils.version import StrictVersion as _StrictVersion\n",
      "2024-04-18 00:22:40 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Random seeds are set to 0\n",
      "2024-04-18 00:22:40 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Using PyTorch version 2.2.1+cu121\n",
      "2024-04-18 00:22:40 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Available GPUs: 1\n",
      "2024-04-18 00:22:40 - \u001b[34m\u001b[1mLOGS   \u001b[0m - CUDNN is enabled\n",
      "2024-04-18 00:22:41 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Setting --ddp.world-size the same as the number of available gpus.\n",
      "2024-04-18 00:22:41 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Directory exists at: results/run_1\n",
      "2024-04-18 00:22:45 - \u001b[32m\u001b[1mINFO   \u001b[0m - distributed init (rank 0): tcp://m-sekhavat-dev2:30786\n",
      "Files already downloaded and verified\n",
      "2024-04-18 00:22:46 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Evaluation dataset details: \n",
      "Cifar10(\n",
      "\troot= \n",
      "\tis_training=False \n",
      "\tnum_samples=10000\n",
      ")\n",
      "2024-04-18 00:22:46 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Evaluation sampler details: BatchSamplerDDP(\n",
      "\t num_repeat=1\n",
      "\t trunc_rep_aug=False\n",
      "\t sharding=False\n",
      "\t disable_shuffle_sharding=False\n",
      "\tbase_im_size=(h=32, w=32)\n",
      "\tbase_batch_size=1\n",
      ")\n",
      "2024-04-18 00:22:46 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Pretrained weights are loaded from results/run_1/checkpoint_best.pt\n",
      "2024-04-18 00:22:46 - \u001b[32m\u001b[1mINFO   \u001b[0m - Trainable parameters: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias']\n",
      "2024-04-18 00:22:46 - \u001b[34m\u001b[1mLOGS   \u001b[0m - \u001b[36mModel\u001b[0m\n",
      "Net(\n",
      "  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
      "  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
      "  (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
      "  (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
      "  (fc3): Linear(in_features=84, out_features=10, bias=True)\n",
      ")\n",
      "\u001b[31m=================================================================\u001b[0m\n",
      "                                Net Summary\n",
      "\u001b[31m=================================================================\u001b[0m\n",
      "Total parameters     =    0.062 M\n",
      "Total trainable parameters =    0.062 M\n",
      "\n",
      "2024-04-18 00:22:46 - \u001b[33m\u001b[1mWARNING\u001b[0m - Profiling not available, dummy_input_and_label not implemented for this model.\n",
      "2024-04-18 00:22:46 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Using DistributedDataParallel.\n",
      "2024-04-18 00:22:46 - \u001b[34m\u001b[1mLOGS   \u001b[0m - \u001b[36mLoss function\u001b[0m\n",
      "CrossEntropy(\n",
      "\t ignore_idx=-1\n",
      "\t class_weighting=False\n",
      "\t label_smoothing=0.0\n",
      ")\n",
      "2024-04-18 00:22:53 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [       1/   10000], loss: 1.137, top1: 0.0, LR: 0.000000, Avg. batch load time: 0.000, Elapsed time:  7.27\n",
      "2024-04-18 00:22:59 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [    2001/   10000], loss: 1.2002, top1: 56.022, LR: 0.000000, Avg. batch load time: 0.000, Elapsed time: 12.85\n",
      "2024-04-18 00:23:04 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [    4001/   10000], loss: 1.2143, top1: 55.886, LR: 0.000000, Avg. batch load time: 0.000, Elapsed time: 18.20\n",
      "2024-04-18 00:23:10 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [    6001/   10000], loss: 1.2187, top1: 56.3406, LR: 0.000000, Avg. batch load time: 0.000, Elapsed time: 23.91\n",
      "2024-04-18 00:23:16 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Epoch:   0 [    8001/   10000], loss: 1.22, top1: 56.6179, LR: 0.000000, Avg. batch load time: 0.000, Elapsed time: 29.44\n",
      "2024-04-18 00:23:21 - \u001b[34m\u001b[1mLOGS   \u001b[0m - *** Evaluation summary for epoch 0\n",
      "\t loss=1.2256 || top1=56.54\n",
      "2024-04-18 00:23:21 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Evaluation took 35.33877229690552 seconds\n"
     ]
    }
   ],
   "source": [
    "! CUDA_VISIBLE_DEVICES=0 corenet-eval \\\n",
    "    --common.config-file projects/playground_cifar10/classification/cifar10.yaml \\\n",
    "    --model.classification.pretrained results/run_1/checkpoint_best.pt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f75994b9-ff15-4195-9e2c-147be276a0fe",
   "metadata": {},
   "source": [
    "We observed the same `top1=54.74` result as we observed in the validation accuracy during training, because the current CIFAR10 implementation uses the same test set for validation and test. In order to differentiate between validation and test, you can access `self.mode` in the dataset, which is a member of `{\"train\", \"val\", \"test\"}`."
   ]
  },
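  {
   "cell_type": "markdown",
   "id": "sketch-mode-split-md",
   "metadata": {},
   "source": [
    "One way to act on `self.mode` is to hold out part of the official CIFAR10 training images for validation and reserve the official test images for `test`. The next cell is a minimal, framework-free sketch of that index bookkeeping; `indices_for_mode` and the 5,000-image hold-out are assumptions for illustration, not CoreNet APIs. In the dataset class, you could then load the torchvision dataset with `train=True` for both `train` and `val`, and wrap it with `torch.utils.data.Subset` using these indices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-mode-split-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def indices_for_mode(\n",
    "    mode: str, num_train: int = 50_000, num_test: int = 10_000, num_val: int = 5_000\n",
    ") -> range:\n",
    "    # 'train' and 'val' index into the official training split; 'test' indexes the official test split.\n",
    "    if mode == 'train':\n",
    "        return range(0, num_train - num_val)\n",
    "    if mode == 'val':\n",
    "        # Hypothetical hold-out: the last num_val training images.\n",
    "        return range(num_train - num_val, num_train)\n",
    "    if mode == 'test':\n",
    "        return range(0, num_test)\n",
    "    raise ValueError(f'Unknown mode: {mode!r}')\n",
    "\n",
    "\n",
    "for mode in ('train', 'val', 'test'):\n",
    "    print(mode, len(indices_for_mode(mode)))"
   ]
  },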
  {
   "cell_type": "markdown",
   "id": "ad735fc4-db86-400a-a546-362c3a54ac18",
   "metadata": {},
   "source": [
    "### Visualizing the classification results\n",
    "It's easy to load a model checkpoint and interact with it in notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8e395824-a625-49ba-ba16-e63fe68a4172",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-04-18 15:52:19 - \u001b[34m\u001b[1mLOGS   \u001b[0m - Pretrained weights are loaded from results/run_1/checkpoint_best.pt\n",
      "2024-04-18 15:52:19 - \u001b[32m\u001b[1mINFO   \u001b[0m - Trainable parameters: ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias']\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAJsElEQVR4nAXBaY+d110A8PM/53+e9e5zZ+bO4vGMHS+ZsZ04MWnqNimJBAoSQuIdUishKvVb8AZVID4Db/gCKELQFhqFhlQUVyZLG9eOd489M3eWuz7PfZazH34/+Pjzf3SehkEbuXUkIwSQeudPKDXIiAcDuEzZQGvrSEno1JHcWLBWS3VaLJ4bhWmoVHXoNV3tX1ciOTpmtck/+Y+Hrf7W+iBEQ2pLCPrA25pAzYMlShsUmtaNrJ8AVd4Xzs4IxCFrEcqFQaUmzoecXQ35WsBYK+E2fMWggbALYSPk48niYGMLvvrD773fRmlySqlygvrCEa9JGmCbY4syQikAjT0hznkGFBl6n4BTQKwnitFGI15tBFcC1s7Kgyp/Oc4kMcV0dKBV9eD5Q2A1uBcIllJCvKs9c8h6zlLhFp6QJEKOA0cq6yqgHCiRdqx0jbQXhz1tM0o8Q44M06TrfZGdfu6pbq5iK63y/HzE3eaS/871GrljzsoA0yTsMdr0hHtmHRTaLBjtI3a1VYQYa2sPHrEZYYuwRVErJF2E0Nup9QnDRdKKBktvoy/K2TPL8c9uL50Nc3Aek6TrPOFBHLE0YBGPExb4WttKaSAOwMe87bygxCfhJoOWJbXydRL2OaxEvAsera21OQ07lAZl6tZV4sX8qxvX4OuSNRKLaaMnDTGu9lQHSRpGxDPZCDrImtYaRMWDJqVd8ARpz9nAyhoA02gQ4crTrw91bS/e3MHIMTsEpgnvQmQCPFOqwqhfizk6yC3khpQGuo5ibcfeyjjYTMOOlBlH4Dz2xFkvhHklhXbWcc7Bg5LzF7/6zfDe82e39r77wx2Gx1KfSKEItgihwrc6653s+Dwq+ZhyaIRtcFJUM09KYDxGShnEYZOQ0qjcgwfGvK+8r5A1OYaOaqPK0NFBp/P0fx9U5PTWXzbDRFsyhaTZHFzSdP3o4ODLfx0ihSqOBiFfkcITm0ZBC5B44g2pKUjqnDNCEcmwydlyEEdAkKJzUHozrxhl6XKrF//2548PX/Zvvr+7vLxTCa/U7OhZ/uUnoyXdRCWF8zlJe8q7XnN7rXfdg56UT6SZOTPjDq1Xs/pEKZaEa41kK+AhccpRBeCXz23d+fbecDLNtDv97eGjL3Me3HNAnakM0dTxlXaEYhazlHmaGfoq0ILhGyEOTqYPsvIESBVj4GhVqqE3gwi9dzYKOpRFtV0IKc5mkzvf/E44r7xeP7dNpEUOk8lIGb20vm7qopxMsNVec4G0/tj5+4V4fJJd7CW3s/nRojzstNfazTess+VsWZcpiEE+jkyTdJb6LDYM5ycnk0ybpNHa7A/SZoNj/Nd/8+PffXHns08/pYxnZTmcTVEHzJEpA5XEq8a/zOr7rejGcu9yzFq2CEenOBnWs7Ou1oyh0Grh2aTRbbb6UXfp4unBZ2nU4azRSFt1XQat9K1b1ze3Bt98883+02eq0IvS4qw6A5hsLr8TxZ25uCclzifT+aGdvIpEZsBNS2GcDyiGWggpRZzGIYa7e9eMnGnLkqR5PDnrbbDnw1fXu2+ezSdHJ/uPnn5NqA9SFooOMrLEkFsSO7sKZXLv7v3y4DMxt3VVp0nabHZqYRhPMIy9NZRR5pzn0F9uFwvNI16qGoKm8tH65qDT61uPR8eHPDKMJ9S3aFVhv3UhTmPHXr16fu/uz86On0w48FoIIGRtZZBnuQNUjjYanSRtlqKqz467W6tJI0LW7S91wau6yA8Pgp0L2+9+96NzqwN4671PL/z68ZPnk9MROEuThCE0Hn9R/PKfv5rt54PuMiJKJdeWV7wn02yeF0UthHUuK4t5LTzyRrfLGOWc9/v9KA7Wt1c/+ouPfvoPf//hB+/n49na6vrf/t1PP/zoTxmTQBSejn5//zfq/v8MnQ4oEZTk3pLtja04CA/PzqI41c4z
xkohaJho5zjyTq8NAEEU3v7wByWKsBm++da7nNvhwaMvf/5vN/74T/rnt957/8bwcHf4Yowf/9Odl3/wa8vnOt0l8CRCniZNDzA6O7HGFpVwhMY8CjlHZAmPdy5fvHhhkwIQQveuXRN+MVmciPr0+f4om54OF4/JE2BHASHm9RsbV6+ex1//58Hea3trywMHEEVprV1h/DybPdnfDxnr9lbCgCGi0QIQ4qX+O9+52e8lzjkpdVnXS51+Vo2KalJU47qcNjaiWfUEFHHaU5JcurSHV85f3lzbBIrtZjtK24vxZDg8WiwyxmMesNF8QoE2hQyCMOh037755vb5taqS3hlvbVEWlRScBVrlWmbO10HDSymdBaPJa9u7SZzgxavXrly5Nh6N5vk8H40tBkl3KStLoSuhhLVOyiqIUmHt+ubW8kq/FlWR180kouBLUdZKeKABT6NIeG+VEsZY7yGKejvnrt9/fBf//K9+BMAOf/GL6TyzQFcHm5ayZ08fGuustVWRhwHO8lkURUDcPBtPZiMGDoi1WtSirOqFVMIRChAQQABkGGljk3iJMX549JTu3dw9f+n8a7uvsyAoy3Je5jtXr2xsbTmrsvmkElVeFYZCZ2UFI4xCqpUAysq6zKq8FAuhSqEqraR13nsGwIFGFKJm2j94+ejgxT06nY2NVdfeeXP31tvtft8YQ8Bdv3WzFLNsMdZGEUo558ubGxtbg3Y7rsVCayW0EkrUspaqVlpY74EgYzHDBmNxEvcHKxey+YmQAo/P9ikj/d752x98v9FpKmWvXL0s1bkv7lybff65syqK+84obcrdvctRCKIuGUPvrRCLWpZKK++Jd95Ya53TWittuq22M36cDXcudfHBi49fHj28/cZP1tevXtm7FIRJFKdSxLfevf3g6/+bT/P56FjVRcDfANBaW2sVk6iVqGUuRGWttdYYK5UWQhRlNddGp3Hr8Oj5/ul/PxsP8e79fzke+qXGXWVNt71BGZXCeQKXdy+/9f3b//XvP5NSY8g8cVWVM0YCHgpRKVUbayhF50lZ5bXIjZVK10rVaRLU8sXDb7+dTMavHjgcT4BAKJU6PdtXSpVVzjHgGPIA/uh733v67YMXj5567yhClo05Z+1WjzqQsnbeEgAp66LM6jpTRjhnwiAenR7+6pNfHuzXUdemLcBOc/fw+KlzXAg5Gh+UZc5ZEAQxYthdarz7wXsnwyMWoFJVlo2ULBBfj6JYGw0AtcgXxbQoplpLAgAMDw/3J6Oz0VTYhtURYRwRtPMaRC2TqEscq4qFJw4ZIgswiLd2Nq/dunF6fBaEvsjHUlR1VVGg3oPzVshaiEqKyjoTJWleZMbZje1OZ6VdavvwyRPKkv8H5nQSpMdt/MkAAAAASUVORK5CYII=",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=32x32>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top 3 Predictions: ['bird: 33.3%', 'dog: 18.5%', 'cat: 13.7%']\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAACAAAAAgCAIAAAD8GO2jAAAJa0lEQVR4nAXBSW+d13kA4Pd9zznfdL87krwcRJEiJVEMJXl2g8IxkiYpmhQxCsQo0EWHX9Bf0k0X7a5Fuw66aTdtiiyaRE4DRyksW6YmSyIpzpfkHb/hjG+fBz/9h58IUoSoADCw8S4gEHkkYimFFEpIBIAAhBIRAQGC86wJUVKEgcCz9zawF0IRIgdt7cx77b213khjykgqIVQsYiFE4GCBAQkQmBmAiCmhVMmYIAaWgC7IsnJXHoxnAEYGDMQCEZARgJAdsHPOecsQJATgYIlQYiwQYqUQEQkQoB1lm0n3Wsha0FbRAmOCSZfrAcFwzPrQDs7dxQjLCr0PjpEZA4BAAEDyDB5YkJBxIoQARtJoJYmAIIRcjvrbJm8OL+fTAMUpJbb7wR+p9pJQ6uLX/2pOv5yXQp9cLrTa2cbOpZh+Di9qNIzKcwBkEChiycErQhnHjMSBUQMbMkRhHlv3q3z48OHhuHj/nbsx+rS7ki+uqiyvxldcnrEZOyOPjg7Gu9X9Gjtx+nZr4VVyMcxMYPaBgwBEJAZmJg41sA3eOl8bVy3W2b3Tdnj2YvfLr7ixJFTGwevJmaun3ujRN7/To6MozUnGTnX2z8aDo8PpxQienvceuuuDpvTCB++8tV5rqytjZaXLKEqMrZy313muf6SuhidUnrrGyq3N9eHp68Hp0Y1tsViPDeHszaOjN2+yrNFbXs97y1f29Wx0GSCdTovBzLSK7p33lp83h1M/wmCtNjrEEl3wIiBTn7vL++nhN0cLcx09Gi2v3f/tV3vn+8/bouosToCdvjy62Hu8+/z11Ea9Jbu+2t9cma9N5SfFRIerWTkry5jCzs7KbtuMuRJBAiupok4QskfJrZP88uhCg0QVSYRiNP7s1ShYsdFbUt110ZgPYVQV1VERfzNCd3myeTx6pw91gZUNhScNCl1xuP9i1Uyuvbs8zqUQKg4ovVBM0Zpuz548nVQoOyvGc4S+n+pb/c5Yi50719c++pSaSwRxurgTZRcN60ujN/qJCEPthAOuXQhIzAZcZcbn3fNGmoOTTlojgzcNTPpns0NdRKLFSFPNHZS6HH18c02mC5vf//Nk6dblqIxJbX780+9NBydv9pTKF1L96lXtqesChBAUcSQxFhH4IC4m3Y3mGGcIlnxwAMaywRA4BGQbxzHky5ejQja6t37wF83b7x8Pxqenl9M6yP7G1nc/vbezOZc6cLXHRBIB+ERRQlYSIinPWFGYkPfsA5dSu/GEZqfNuThWaEARNxKpxJzS9cLdj5LNdx8/29v7+qu40X3//s1saa5x88N0fnnh8uD1b/6DDx8lAj1zFEkKBE4CsFJo52nKg6BHEIy0rmSgfanursxFxzaKkyyWFGBh+XrJ8n+fHPznf382HZxglF+eH//0x99pdntZu09Jiov3BD1WiAkSSxGSJGirwOdtddIt62rg6hkgy8BKG3HGk8XFRrdqciNLFdnClbV/eVH89r/+6cmDn4PVMspO+IfTyy2VpOx5evTs5OXvIwGKCAJ5KQBj46a59G5OHsaVrcE7RIFkJtoUrq6r17iv5kUnT2ICCJ4RXKJ23v1ACPK2yITuxm42Geq6nBb1Lz97tPvslRIkpVCSYkVpItspR3l4Mc9TR2Bj8CrYiMank3JSeyuufPkqOU1zqQgEBEI38adrW3du377TzLJWlirWdXk5HZ8/fvTo4WcPvGQkRIA4Us1U5glGMZytNwZtYkhJtCV1BGaSQEpSUiWCkgNbvpNDPHMC2VM1Cq/H08trnV7a65IU9cXx2bMv9/be7H7xdHVeoITa+hyEIIiUyBudJ2py0NUoOM2aMo29zY2bytZCK0pIiqAoytNuM24Ee8Xspg0z9Be7L/a/nXZm3UWL3Al5eup+t/+gM5/c3coPLuvB1awBHLyF4KCx+uvDjAe/
WbjWjEQqmQxHFCKp4lgIRlBCyE6rI7Su9GycmdkNMtS+Oiv+aummIawB27IxjNRb75nVFSPJr1N7n+XYx73ADvyw5OPkrRvxozjJEVNnSgWCKZZSJYJcKlNBSrP/wh2a/swukGy1xYl5ORju5vPfTtsVRBjqgZqu315tifOIYKHd6rVKK+fyuZt7A/8vX3YKz005hzKgFJKixAUFIJP2fAiFQAPsKqwOmjLkMpLs68bV0bpq9P+etz/x5Yeti7hVzS/nvQyVayGHKE6zVphqGrjkZ7/41fNq++5C3B22q4VjDxaD0IBBkUzi1LjgkRG8iBxEmXd5XZijLwydDzfayTOmf766/kWgP12YrYEjqQgIvdEhCiJ7eFD82/+9efn8ZGur+73te8VLkZ5Hk741tmJfoGCJTnurDUopJHsEb+upvHhQ1i/Ld/9g6/s/+WRg8n/82f7DF9XzFxerzdn2arbclrNSn0+ujqbi2WFRa7/ea//1D+7cun3rV69+Ls6zWWZqBFNPgceyrkpvXRBoo2Bli+v47MFw9vR0a3Nr5623b2ysbZTT7p8t/vsvzz5/7L/en3316pwEMiJ6H0nsd/PvvL/2yUd3b9zcquuyF7t6ahbG6rTnLUWAufTAiBp4pn3DVI2JvjOuPt/+4+vv9T7Yvv+OMToK5lZf/O2frBy9jU/2G3vHw9GkJICluXR7be7mtd5cq5s0u0qgbLavX1ueHs/S0p+1nIgzUi3JUer53AR3ad8+99+VZm/j4+GHd/7mD5d/FMex0bUEkMgK6tU2r3ywqaRCqyFAlMRCoDclBydChcGgjOaX19X00FW1m+oihQgjuXf6rZn5aGx3KvGtZvz7a63/8SJenrvXyJtOV0pKp7UwlfAaIFTWYNSJ40wQCkEQLIFhYGAP3oCrIzBZlnTZNVFO0Hpdyadf/yWrhJLZ/Nwv1roPJA05NBKQzhglI28rXZchhFgoodKgul52AgZBzOwDxjJW6EsEZqcpypNmzyQNWUxbKYya3oRYCrebJnu9uWd5dwq2IioiEUUUMwMDUPCxQAgeSXDU8arvgwI2wCYECyKSUYuEDKbwpvCUyKV7fPKNtQcAQZM0rOTi2t/FUZBs6nPvG9DMIUkySY0A4Kwj6yICZu+CcmnfcwqudGZKkiU4BA4OkAQD1JMLOx2HuKNZGKuLoqqkc17LLNMQsL6qpGArdUFRGnwAYa3VuqRqHLiO2LHMjFfWG9TjUA2dJBReypTZBYHgfFVMS11hOpsV49FkUjdtVVpnq/8HSx6mnJgbezwAAAAASUVORK5CYII=",
      "text/plain": [
       "<PIL.Image.Image image mode=RGB size=32x32>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top 3 Predictions: ['dog: 28.1%', 'cat: 22.6%', 'bird: 15.9%']\n"
     ]
    }
   ],
   "source": [
    "from corenet.options.opts import get_training_arguments\n",
    "from corenet.modeling import get_model\n",
    "from PIL import Image\n",
    "import torch\n",
    "from torchvision.transforms import Compose, Resize, PILToTensor, CenterCrop\n",
    "from torchvision.transforms import ToPILImage\n",
    "from corenet.data.datasets.classification.playground_dataset import Cifar10\n",
    "\n",
    "config_file = \"projects/playground_cifar10/classification/cifar10.yaml\"\n",
    "pretrained_weights = \"results/run_1/checkpoint_best.pt\"\n",
    "\n",
    "opts = get_training_arguments(\n",
    "    args=[\n",
    "        \"--common.config-file\",\n",
    "        config_file,\n",
    "        \"--model.classification.pretrained\",\n",
    "        pretrained_weights,\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Load the model\n",
    "model = get_model(opts)\n",
    "model.eval()\n",
    "\n",
    "for image_path in [\"assets/cat.jpeg\", \"assets/dog.jpeg\"]:\n",
    "    image = Image.open(image_path).convert(\"RGB\")\n",
    "    img_transforms = Compose([CenterCrop(600), Resize(size=(32, 32)), PILToTensor()])\n",
    "\n",
    "    # Transform the image, normalize between 0 and 1\n",
    "    input_tensor = img_transforms(image)\n",
    "\n",
    "    # Show the transformed image\n",
    "    ToPILImage()(input_tensor).show()\n",
    "\n",
    "    input_tensor = input_tensor.to(torch.float).div(255.0)\n",
    "\n",
    "    # add dummy batch dimension\n",
    "    input_tensor = input_tensor[None, ...]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        logits = model(input_tensor)[0]\n",
    "        probs = torch.softmax(logits, dim=-1)\n",
    "        predictions = sorted(zip(probs.tolist(), Cifar10.CLASS_NAMES), reverse=True)\n",
    "        print(\n",
    "            \"Top 3 Predictions:\",\n",
    "            [f\"{cls}: {prob:.1%}\" for prob, cls in predictions[:3]],\n",
    "        )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
